import pickle
import time
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from jiquanquan.sample import match_algorithm


def trainParameter(fileName):
    """
    依据数据集统计π，A,B这三个参数
    :param fileName: 训练文本地址
    :return: 三个参数
    """
    # 定义一个查询字典，用于映射四种标记在数组中对应位置，方便查询
    # B:词语开头；M：一个词语的中间词；E：一个词语的结果；S：非词语，单个词
    statuDict = {'B': 0, 'M': 1, 'E': 2, 'S': 3}

    # 每个字只有四种状态，所以下面各类初始化中大小参数为4
    # 初始化PAI的一维数组，涉及四种状态
    PAI = np.zeros(4)
    # 初始化状态转移概率矩阵A，涉及四种状态分别到四种状态的转移，大小为4*4
    A = np.zeros((4, 4))
    # 初始化观测概率矩阵，涉及到四种状态到每个字的发射概率
    # 这里采用一个65536的空间保证对于所有的汉字都能找到对应的位置进行存储
    B = np.zeros((4, 65536))

    # 读取所有内容
    corpus = open(fileName, encoding='utf-8').readlines()
    # print(corpus)
    # 读取训练文本
    fr, test_data = train_test_split(corpus, test_size=0.2, random_state=10)
    # 文本中的每一行认为是一个训练样本,训练样本已经分词完毕，词语之间用间隔空开，统计时我们按照如下思路：
    """
        1.先将句子按照空格隔开，例如例句中5个词语，隔开后变成一个长度为5的列表，每个元素为一个词语
        2.对每个词语长度进行判断：
              如果为1认为该词语是S，即单个字
              如果为2则第一个是B，表开头，第二个为E，表结束
              如果大于2，则第一个为B，最后一个为E，中间全部标为M，表中间词
        3.统计PI：该句第一个字的词性对应的PI中位置加1
                  例如：PI = [0， 0， 0， 0]，当本行第一个字是B，即表示开头时，PI中B对应位置为0，
                       则PI = [1， 0， 0， 0]，全部统计结束后，按照计数值再除以总数得到概率
          统计A：对状态链中位置t和t-1的状态进行统计，在矩阵中相应位置加1，全部结束后生成概率
          统计B：对于每个字的状态以及字内容，生成状态到字的发射计数，全部结束后生成概率
    """
    for line in tqdm(fr):
        # 对单行句子去掉结尾的"\n"并按空格进行切割
        curLine = line.strip().split()
        # 对词性的标记放在该列表中
        wordLabel = []
        # 对分割后的每一个单词进行遍历
        for i in range(len(curLine)):
            # 如果单词长度为1，那么将该词标为"S"，即单个词
            if len(curLine[i]) == 1:
                label = "S"
            # 如果长度不为1，开头为B，最后为E，中间添加长度-2个M
            else:
                label = "B" + 'M' * (len(curLine[i]) - 2) + 'E'

            # 如果是单行开头的第一个字，PAI中对应位置加1
            if i == 0:
                PAI[statuDict[label[0]]] += 1

            # 对于单词中的每一个字，在生成的状态链中统计B
            for j in range(len(label)):
                # 遍历状态链中每一个状态，并找到对应的中文汉字，并在B对应的位置中加1
                # 因为是中文分词，使用ord(汉字)即可找到其对应汉字编码
                B[statuDict[label[j]]][ord(curLine[i][j])] += 1

            # 在整行的状态链中添加该单词的状态链
            wordLabel.extend(label)

        # 单行所有单词都结束后，统计状态转移A矩阵
        # 因为A涉及了前一个状态，所以需要等整条状态链都生成后再开始统计
        for i in range(1, len(wordLabel)):
            # 统计t时刻状态和t-1时刻状态的所有状态组合的出现次数
            A[statuDict[wordLabel[i - 1]]][statuDict[wordLabel[i]]] += 1

    # 将π，A,B转化为概率—归一化
    # 对PAI求和，概率生成中的分母
    sumPAI = np.sum(PAI)
    # 遍历PAI中的每一个元素，元素出现的次数/总次数等于概率
    for i in range(len(PAI)):
        # 为了防止结果下溢出，在概率上我们将其转换为log对数形式
        # 当单向概率为0的时候，log没有定义，因此需要单独判断,手动赋予一个极小值
        if PAI[i] == 0:
            PAI[i] = -3.14e+100
        else:
            PAI[i] = np.log(PAI[i] / sumPAI)

    # 对A求和，概率中的分母
    for i in range(len(A)):
        sumA = np.sum(A[i])
        for j in range(len(A[i])):
            if A[i][j] == 0:
                A[i][j] = -3.14e+100
            else:
                A[i][j] = np.log(A[i][j] / sumA)

    # 对B求和，概率中的分母
    for i in range(len(B)):
        sumB = np.sum(B[i])
        for j in range(len(B[i])):
            if B[i][j] == 0:
                B[i][j] = -3.14e+100
            else:
                B[i][j] = np.log(B[i][j] / sumB)

    return PAI, A, B


def loadArtical(fileName):
    """
    加载文章
    :param fileName: 文件路径
    :return: 文章内容
    """
    # 初始化文章列表
    artical = []
    # 打开文件
    with open(fileName, 'rb') as text:
        fr = pickle.load(text)
    # fr = open(fileName, encoding='utf-8')
    # 遍历读取文章每一行
    for line in fr:
        if line != '':
            # 去掉每一行末尾的"\n"
            curLine = line.strip()
            # 将当前行添加到文章列表中
            artical.append(curLine)

    return artical


def participle(artical, PAI, A, B):
    """
    基于维特比算法实现的分词
    :param artical: 要分词的文章列表
    :param PAI: 初始状态概率向量PAI
    :param A:状态转移矩阵
    :param B:观测概率矩阵
    :return:分词后的文章
    """
    # 初始化分词后的文章列表
    retArtical = []
    # 对文章按行读取
    for line in tqdm(artical):
        # 初始化δ，δ存放四种状态的概率值，因为状态链中每个状态都有四个概率值
        # 因此，大小为(文本长度*四种状态)
        delta = [[0 for i in range(4)] for i in range(len(line))]

        # 第一步：初始化
        for i in range(4):
            # 初始化δ状态链中第一个状态的四种状态概率
            delta[0][i] = PAI[i] + B[i][ord(line[0])]
        # 初始化ψ，初始时为0
        psi = [[0 for i in range(4)] for i in range(len(line))]

        # 第二步：递推,依次处理整条链
        for t in range(1, len(line)):
            # 对于链中的四种状态，求四种状态概率
            for i in range(4):
                # 初始化一个临时列表，用于存放四种概率
                tmpDelta = [0] * 4
                for j in range(4):
                    tmpDelta[j] = delta[t - 1][j] + A[j][i]

                # 找到最大的那个δ*a
                maxDelta = max(tmpDelta)
                # 记录最大值对应的状态
                maxDeltaIndex = tmpDelta.index(maxDelta)
                # 将找到的最大值*b放入delta中
                delta[t][i] = maxDelta + B[i][ord(line[t])]
                # 在ψ中记录对应的最大状态索引
                psi[t][i] = maxDeltaIndex
        # 建立一个状态链列表，开始生成状态链
        sequence = []

        # 第三步：终止；获取最后一个状态概率对应的索引
        i_opt = delta[len(line) - 1].index(max(delta[len(line) - 1]))
        # 在状态链中添加索引
        sequence.append(i_opt)

        # 第四步:最优路径回溯
        # 从后往前遍历整条链
        for t in range(len(line) - 1, 0, -1):
            # 不断从当前时刻t的ψ列表中读取到t-1的最优状态
            i_opt = psi[t][i_opt]
            # 将状态放入列表中
            sequence.append(i_opt)
        # 因为是从后往前将状态放入列表中，这里需要翻转一下
        sequence.reverse()

        # 基于预测出来的状态开始进行分词
        curLine = ''
        # 遍历该行的每一个字
        for i in range(len(line)):
            # 在字符串中加入该字
            curLine += line[i]
            # 如果该字是3：S->单个词  或  2:E->结尾词 ，并且不是这句话的最后一个字，则在该字后面加上分隔符
            if (sequence[i] == 3 or sequence[i] == 2) and i != (len(line) - 1):
                curLine += '|'
        # 在返回的列表中添加分词后的该行
        retArtical.append(curLine)
    # 返回分词结果
    return retArtical


if __name__ == '__main__':
    # # 记录开始时间
    # start = time.time()
    # # 依据数据集统计π，A,B这三个参数
    #
    # PAI, A, B = trainParameter("../docs/corpus/tempData/199801_processed_text.txt")
    # # 根据训练得到的模型做测试
    # artical = loadArtical("../docs/corpus/tempData/199801_test_text.txt")
    # # 打印原文
    # # print('---------------------------------------原文------------------------------------------------')
    # # for line in artical:
    # #     print(line)
    # # 进行分词
    # partiArtical = participle(artical, PAI, A, B)
    # # 打印分词结果
    # print('--------------------------------------分词-------------------------------------------------')
    # for line in partiArtical:
    #     print(line)
    #
    # end = time.time()
    # print('time span =', end - start, 's')

    # ----------------------上述代码来源：http://blog.17baishi.com/9969/--------------------------

    # 默认参数为14语料库
    match_algorithm.run_hmm_model()

    # 98语料库测试
    # processed_text_path = "../docs/corpus/tempData/199801_processed_text.txt"
    # test_text_path = "../docs/corpus/tempData/199801_test_text.txt"
    # hmm_res_path = '../docs/corpus/resultData/199801_hmm_result.txt'
    # match_algorithm.run_hmm_model(processed_text_path, test_text_path, hmm_res_path)